[quantization] Introduce wrapper for Qwen3VLTextRotaryEmbedding #498

Open

dvsav wants to merge 1 commit into Samsung:main from dvsav:quant_text_rotary_embed

Conversation


@dvsav dvsav commented Feb 17, 2026

This change introduces the QuantQwen3VLTextRotaryEmbedding wrapper to support post-training quantization of the Qwen3VLTextRotaryEmbedding module.

Why?

The Qwen3VLTextRotaryEmbedding module is used in the language model of Qwen3-VL.
Trying to quantize Qwen3VLTextRotaryEmbedding via PTQ raises the exception PTQQuantizer: no quantization wrapper for Qwen3VLTextRotaryEmbedding.

What?

This change introduces:

  • Class QuantQwen3VLTextRotaryEmbedding (tico/quantization/wrapq/wrappers/qwen_vl/quant_text_rotary_embedding.py).
  • Unit tests: class TestQuantQwen3VLTextRotaryEmbedding (test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py), skipped if the transformers package is not installed.
  • New entry tico.quantization.wrapq.wrappers.qwen_vl.quant_text_rotary_embedding in _CORE_MODULES (tico/quantization/wrapq/wrappers/registry.py).
  • Example of Qwen3VLTextRotaryEmbedding quantization and conversion to Circle (tico/quantization/wrapq/examples/qwen/quantize_qwen_text_rotary_embedding.py).
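
For context, below is a minimal generic sketch of the observe-then-quantize pattern that PTQ wrappers of this kind typically follow. All names here (MinMaxObserver, ObservedModule, the mode strings) are illustrative placeholders, not TICO's actual API:

```python
import torch

# Illustrative sketch only: a generic observer-based PTQ wrapper pattern.
# Class and mode names are hypothetical and do not reflect TICO's actual
# QuantQwen3VLTextRotaryEmbedding implementation.
class MinMaxObserver:
    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, t: torch.Tensor) -> None:
        # Track the running activation range seen during calibration.
        self.min_val = min(self.min_val, t.min().item())
        self.max_val = max(self.max_val, t.max().item())

    def scale_zero_point(self, bits: int = 8):
        # Derive affine quantization parameters from the observed range.
        qmax = 2**bits - 1
        scale = max(self.max_val - self.min_val, 1e-8) / qmax
        zero_point = round(-self.min_val / scale)
        return scale, zero_point


class ObservedModule(torch.nn.Module):
    """Wraps a float module: collects stats in 'calibrate' mode, fake-quantizes afterwards."""

    def __init__(self, wrapped: torch.nn.Module):
        super().__init__()
        self.wrapped = wrapped
        self.observer = MinMaxObserver()
        self.mode = "calibrate"

    def forward(self, *args, **kwargs):
        out = self.wrapped(*args, **kwargs)
        if self.mode == "calibrate":
            self.observer.observe(out)
            return out
        scale, zp = self.observer.scale_zero_point()
        q = torch.clamp(torch.round(out / scale) + zp, 0, 255)
        return (q - zp) * scale  # fake-quantized (quantize-dequantize) output
```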

Unit Tests

Unit test results with coverage information:

```sh
$ coverage run -m pytest test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py -v
======================================================================================= test session starts ========================================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 12 items                                                                                                                                                                                 

test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_activation_stats_collected PASSED                                    [  8%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_cos_sin_relationship PASSED                                          [ 16%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_different_batch_sizes PASSED                                         [ 25%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_different_sequence_lengths PASSED                                    [ 33%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_dtype_override PASSED                                                [ 41%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_mode_transitions PASSED                                              [ 50%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_no_learnable_parameters PASSED                                       [ 58%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_observer_count PASSED                                                [ 66%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_output_range PASSED                                                  [ 75%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_output_shape PASSED                                                  [ 83%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_quantised_output_close PASSED                                        [ 91%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py::TestQuantQwen3VLTextRotaryEmbedding::test_registration_in_registry PASSED                                      [100%]

================================================================================== 12 passed, 2 warnings in 6.66s ==================================================================================
```

Coverage info (irrelevant files skipped):
```sh
$ coverage report -m
Name                                                                      Stmts   Miss  Cover   Missing
-------------------------------------------------------------------------------------------------------
...
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_rotary_embedding.py      52      0   100%
...
-------------------------------------------------------------------------------------------------------
TOTAL                                                                     10300   6628    36%
```


dvsav commented Feb 18, 2026

For Reviewers

Below is the source code of the Qwen3VLTextRotaryEmbedding module, which can be used to check the correctness of the QuantQwen3VLTextRotaryEmbedding implementation:

```python
# transformers/models/qwen3_vl/modeling_qwen3_vl.py

class Qwen3VLTextRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3VLTextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

        self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])

    @staticmethod
    def compute_default_rope_parameters(
        config: Qwen3VLTextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
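        # Editor's note: equivalently inv_freq[i] = base ** (-2 * i / dim)
        # for i = 0, ..., dim // 2 - 1, a geometric sequence decaying from 1.0.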
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # In contrast to other models, Qwen3VL has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHWTHW...TT], preserving frequency continuity.
        Args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        Returns:
            freqs_t: (bs, seq_len, head_dim // 2)
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t
```
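
To make the interleaving concrete, here is a small standalone demonstration of the layout transformation, using toy sizes (head_dim // 2 = 8 and mrope_section = [4, 2, 2] are hypothetical; the PR's default is [24, 20, 20]) and labeled values rather than real frequencies:

```python
import torch

# Toy demonstration of the interleaved MRoPE layout. T values are encoded as
# 1xx, H as 2xx, W as 3xx so the interleaved result is easy to read.
mrope_section = [4, 2, 2]       # hypothetical toy sizes
head_half = sum(mrope_section)  # head_dim // 2 == 8

freqs = torch.stack(
    [torch.arange(head_half, dtype=torch.float) + 100 * k for k in (1, 2, 3)]
)  # shape (3, head_dim // 2); batch/seq dims omitted for brevity

freqs_t = freqs[0].clone()
for dim, offset in enumerate((1, 2), start=1):  # H, W
    idx = slice(offset, mrope_section[dim] * 3, 3)
    freqs_t[..., idx] = freqs[dim, ..., idx]

print(freqs_t.tolist())
# [100.0, 201.0, 302.0, 103.0, 204.0, 305.0, 106.0, 107.0]
#    T      H      W      T      H      W      T      T
```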

This change introduces QuantQwen3VLTextRotaryEmbedding wrapper to support post-training quantization of Qwen3VLTextRotaryEmbedding module.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
@dvsav dvsav force-pushed the quant_text_rotary_embed branch from d2f6509 to 6a07e4b on February 18, 2026 08:50
Comment on lines +127 to +212
```python
def apply_interleaved_mrope(self, freqs, mrope_section):
    """
    Apply interleaved MRoPE to 3D rotary embeddings.
    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THWTHWTHW...TT], preserving frequency continuity.

    Args:
        freqs: (3, bs, seq_len, head_dim // 2)
        mrope_section: (3,)

    Returns:
        freqs_t: (bs, seq_len, head_dim // 2)

    Design Note:
        This implementation uses slice_copy, index_select, and cat
        to avoid the yet-unsupported slice_scatter operation with step=3
        and the unsupported in-place operator index_put.default.
    """
    # Start with the T dimension (some positions kept, some replaced)
    freqs_t_base = freqs[0]

    # For each dimension (H, W), extract the frequency bands to be interleaved
    h_w_bands = []

    for dim, offset in enumerate((1, 2), start=1):  # H, W dimensions
        length = mrope_section[dim] * 3
        indices = torch.arange(offset, length, 3, device=freqs.device)

        # Select frequency bands from the H/W dimensions
        # freqs[dim] has shape (bs, seq_len, head_dim//2)
        # index_select on the last dim: (bs, seq_len, num_selected)
        freqs_bands = freqs[dim].index_select(dim=-1, index=indices)
        h_w_bands.append(freqs_bands)

    # Build the interleaved output: replace specific T indices with H/W bands.
    # The interleaving pattern is T0, H1, W2, T3, H4, W5, ... until the H/W
    # bands given by mrope_section run out, after which T fills the rest.

    # Strategy: slice the T dimension into single-index chunks, substitute
    # H/W bands where the pattern requires, then concatenate.
    chunks = []

    # Total length of the last dimension
    total_len = freqs_t_base.shape[-1]

    for i in range(total_len):
        # Determine which dimension this position belongs to
        mod = i % 3

        if mod == 0:
            # T position: slice just this index from T
            chunk = freqs_t_base[..., i : i + 1]
            chunks.append(chunk)
        elif mod == 1:
            # H position: take the corresponding band from H
            band_idx = (i - 1) // 3
            if band_idx < h_w_bands[0].shape[-1]:
                chunk = h_w_bands[0][..., band_idx : band_idx + 1]
                chunks.append(chunk)
            else:
                # Fall back to T if out of bounds
                chunk = freqs_t_base[..., i : i + 1]
                chunks.append(chunk)
        else:  # mod == 2
            # W position: take the corresponding band from W
            band_idx = (i - 2) // 3
            if band_idx < h_w_bands[1].shape[-1]:
                chunk = h_w_bands[1][..., band_idx : band_idx + 1]
                chunks.append(chunk)
            else:
                # Fall back to T if out of bounds
                chunk = freqs_t_base[..., i : i + 1]
                chunks.append(chunk)

    # Concatenate all chunks
    freqs_t = torch.cat(chunks, dim=-1)

    return freqs_t
```
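
As a quick sanity check (hypothetical, not part of the PR's test suite), the functional rewrite can be compared against the original in-place reference from transformers on random inputs:

```python
import torch

# Hypothetical equivalence check between the in-place reference quoted from
# transformers above and a compact restatement of the PR's functional rewrite.
def reference_interleaved_mrope(freqs, mrope_section):
    freqs_t = freqs[0].clone()  # clone so the strided assignment is safe here
    for dim, offset in enumerate((1, 2), start=1):  # H, W
        idx = slice(offset, mrope_section[dim] * 3, 3)
        freqs_t[..., idx] = freqs[dim, ..., idx]
    return freqs_t

def functional_interleaved_mrope(freqs, mrope_section):
    h_w_bands = [
        freqs[dim].index_select(
            dim=-1,
            index=torch.arange(offset, mrope_section[dim] * 3, 3, device=freqs.device),
        )
        for dim, offset in enumerate((1, 2), start=1)
    ]
    chunks = []
    for i in range(freqs.shape[-1]):
        mod, band_idx = i % 3, i // 3
        if mod == 1 and band_idx < h_w_bands[0].shape[-1]:
            chunks.append(h_w_bands[0][..., band_idx : band_idx + 1])
        elif mod == 2 and band_idx < h_w_bands[1].shape[-1]:
            chunks.append(h_w_bands[1][..., band_idx : band_idx + 1])
        else:
            chunks.append(freqs[0][..., i : i + 1])
    return torch.cat(chunks, dim=-1)

freqs = torch.randn(3, 2, 5, 64)  # (3, bs, seq_len, head_dim // 2)
mrope_section = [24, 20, 20]      # default from the PR
assert torch.equal(
    reference_interleaved_mrope(freqs, mrope_section),
    functional_interleaved_mrope(freqs, mrope_section),
)
print("functional rewrite matches the in-place reference")
```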
@dvsav dvsav commented Feb 18, 2026

Note for Reviewers

Attempting to replicate the original Qwen3VLTextRotaryEmbedding.apply_interleaved_mrope implementation leads to errors at the time of conversion to Circle.
The original implementation uses slice(offset, length, 3), which emits a slice_scatter operator with step=3 when the model is exported.
When that is converted to Circle, the DecomposeSliceScatter pass fails with the following error: RuntimeError: slice_scatter with step > 1 is not yet supported. Node: slice_scatter.
Approaches leveraging in-place operations don't work either.
For example:

```python
# Create the list of indices manually (avoids step=3 in slice)
idx = list(range(offset, length, 3))

# Assign in place via advanced indexing
freqs_t[..., idx] = freqs[dim, ..., idx]
```

This approach fails because it emits an index_put operator that isn't supported in Circle (tico/utils/convert.py raises tico.utils.errors.NotYetSupportedError: NOT SUPPORTED OPERATOR IN GRAPH MODULE).
The same goes for the following example (using index_copy_, which generates index_put as well):

```python
# Create a tensor of indices using torch.arange
# (avoids a Python list, which causes index_put)
indices = torch.arange(offset, length, 3, device=freqs.device)

# Select from the source dimension
# freqs has shape (3, batch, seq_len, head_dim//2)
# Select all batch and seq dims, only specific indices in the head_dim dim
src_selected = freqs[dim].index_select(dim=-1, index=indices)

# Copy into the target with index_copy_ (also lowered to index_put)
freqs_t.index_copy_(dim=-1, index=indices, source=src_selected)
```

The only solution that worked was a pure functional approach using basic tensor slicing (converted to slice_copy operators during torch.export), index_select, and cat.

Why in-place operations fail: any in-place tensor update (such as tensor[...] = value, index_copy_, or index_put_) is traced by torch.export to an operator that mutates tensor memory. The Circle (as well as TFLite) runtime model is designed for functional computation without in-place mutation, so these operators are not supported (see the sketch after the list below).

The functional approach works because it:

  • Builds intermediate tensors via slice_copy and index_select (read-only operations).
  • Combines them via cat (creates a new tensor, doesn't modify existing ones).
  • Never mutates tensors in-place, thus avoiding unsupported operators.
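
A minimal sketch of this difference, using plain torch.export with hypothetical module names (the exact operators in the trace may vary by PyTorch version):

```python
import torch

# Hypothetical illustration: compare what torch.export traces for an in-place
# strided assignment vs. a functional slice + cat rewrite.
class InPlace(torch.nn.Module):
    def forward(self, x):
        y = x.clone()
        y[..., 1::3] = 0.0  # in-place update of a strided slice
        return y

class Functional(torch.nn.Module):
    def forward(self, x):
        # Build single-index chunks and concatenate; nothing is mutated.
        parts = [
            torch.zeros_like(x[..., i : i + 1]) if i % 3 == 1 else x[..., i : i + 1]
            for i in range(x.shape[-1])
        ]
        return torch.cat(parts, dim=-1)

for mod in (InPlace(), Functional()):
    ep = torch.export.export(mod, (torch.randn(2, 6),))
    ops = [n.target for n in ep.graph_module.graph.nodes if n.op == "call_function"]
    print(type(mod).__name__, ops)
# Expect a mutation-style op (e.g. slice_scatter with step=3) in the InPlace
# trace, and only slice/zeros_like/cat-style ops in the Functional trace.
```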

@dvsav dvsav marked this pull request as ready for review February 18, 2026 09:48